{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "E4QKXrRAbuo8"
   },
   "source": [
    "# TUTORIAL 1: Introduction to TorchGAN\n",
    "**Author** - [Aniket Das](https://aniket1998.github.io) & [Avik Pal](https://avik-pal.github.io)\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/torchgan/torchgan/blob/master/tutorials/Tutorial%201.%20Introduction%20to%20TorchGAN.ipynb)\n",
    "\n",
    "This tutorial introduces you to the basics of [TorchGAN](https://github.com/torchgan/torchgan) to define, train and evaluate **Generative Adversarial Networks** easily in \n",
    "**PyTorch**. This tutorial mainly explores the library's core features, the predefined losses and the models. TorchGAN is designed to be highly and very easily extendable and at the end of the tutorial, we see how seamlessly you can integrate your own custom losses and models into the API's training loop"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "kbAlenchb04M"
   },
   "source": [
    "This tutorial assumes that your system has **PyTorch** and **TorchGAN** installed properly. If not, the following code block will try to install the **latest tagged version** of TorchGAN. If you need to use some other version head over to the installation instructions on the [official documentation website](https://torchgan.readthedocs.io/en/latest/)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    import torchgan\n",
    "\n",
    "    print(f\"Existing TorchGAN {torchgan.__version__} installation found\")\n",
    "except ImportError:\n",
    "    import subprocess\n",
    "    import sys\n",
    "\n",
    "    subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"torchgan\"])\n",
    "    import torchgan\n",
    "\n",
    "    print(f\"Installed TorchGAN {torchgan.__version__}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "lySoUvC5b_I0"
   },
   "source": [
    "## IMPORTS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "MK5KQ8yLcAiB"
   },
   "outputs": [],
   "source": [
    "# General Imports\n",
    "import os\n",
    "import random\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.animation as animation\n",
    "import numpy as np\n",
    "from IPython.display import HTML\n",
    "\n",
    "# Pytorch and Torchvision Imports\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torchvision\n",
    "from torch.optim import Adam\n",
    "import torch.nn as nn\n",
    "import torch.utils.data as data\n",
    "import torchvision.datasets as dsets\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.utils as vutils\n",
    "\n",
    "# Torchgan Imports\n",
    "import torchgan\n",
    "from torchgan.models import *\n",
    "from torchgan.losses import *\n",
    "from torchgan.trainer import Trainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 35
    },
    "colab_type": "code",
    "id": "KmILA2JycDHh",
    "outputId": "3e6cd571-8bb0-4100-f2ff-72e0f259a3cf"
   },
   "outputs": [],
   "source": [
    "# Set random seed for reproducibility\n",
    "manualSeed = 999\n",
    "random.seed(manualSeed)\n",
    "torch.manual_seed(manualSeed)\n",
    "print(\"Random Seed: \", manualSeed)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "v3_Kpcicc02k"
   },
   "source": [
    "## DATA LOADING AND PREPROCESSING\n",
    "This tutorial uses the **MNIST** dataset for illustration purposes and convergence time issues. We apply the following transformations to the raw dataset to speed up training\n",
    "1. MNIST digits have a size of $28 \\times 28$ by default. The default models present in **TorchGAN** assume that the image dimensions are a perfect power of 2 (though this behavior can be very easily overriden). For the purposes of this tutorial where we use the default models out of the box, we rescale the images to $32 \\times 32$\n",
    "2. The images are normalized with a mean and standard deviation of 0.5 for each channel. This has been observed to enable easier training (one can also choose to normalize with the per channel mean and standard deviation)\n",
    "\n",
    "We then wrap the dataset in a **DataLoader**. This is done because **TorchGAN Trainer** , which shall be explored in later sections requires said DataLoader to be passed as a parameter while training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "4WKLUqWKcFdZ"
   },
   "outputs": [],
   "source": [
    "dataset = dsets.MNIST(\n",
    "    root=\"./mnist\",\n",
    "    train=True,\n",
    "    transform=transforms.Compose(\n",
    "        [\n",
    "            transforms.Resize((32, 32)),\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize(mean=(0.5,), std=(0.5,)),\n",
    "        ]\n",
    "    ),\n",
    "    download=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "N-Jom_tpjYt0"
   },
   "outputs": [],
   "source": [
    "dataloader = data.DataLoader(dataset, batch_size=64, shuffle=True)"
   ]
  },
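  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick check of the normalization step, the sketch below applies the formula $(x - mean) / std$ to a few pixel values in plain Python. `normalize_pixel` is a hypothetical helper written only for this illustration and is not part of torchvision. It assumes the input has already been scaled to $[0, 1]$ by `ToTensor`. With a mean and standard deviation of 0.5, the endpoints 0 and 1 map to -1 and 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: the per-channel formula applied by Normalize(mean=(0.5,), std=(0.5,)).\n",
    "# normalize_pixel is a hypothetical helper, not a torchvision function.\n",
    "def normalize_pixel(value, mean=0.5, std=0.5):\n",
    "    return (value - mean) / std\n",
    "\n",
    "\n",
    "# ToTensor() yields values in [0, 1]; check both endpoints and the midpoint.\n",
    "for raw in (0.0, 0.5, 1.0):\n",
    "    print(raw, \"->\", normalize_pixel(raw))"
   ]
  },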
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "hmguhfJQnZep"
   },
   "source": [
    "## NETWORK ARCHITECTURE : DCGAN\n",
    "This tutorial uses the highly popular Deep Convolutional GAN or DCGAN architecture. **TorchGAN** provides a generalised implementation of DCGAN out of the box. \n",
    "\n",
    "### GENERATOR NETWORK\n",
    "The DCGAN Generator receives an input noise vector of size $batch\\ size \\times encoding\\ dims$. It outputs a tensor of $batch\\ size \\times 1 \\times 32 \\times 32$ corresponding to a batch of generated image samples. The generator transforms the noise vector into images in the following manner\n",
    "1. **Channel Dimension**: $encoding\\ dims \\rightarrow d \\rightarrow \\frac{d}{2} \\rightarrow \\frac{d}{4} \\rightarrow 1$.\n",
    "2. **Image size**: $(1 \\times 1) \\rightarrow (4 \\times 4) \\rightarrow (8 \\times 8) \\rightarrow (16 \\times 16) \\rightarrow (32 \\times 32)$.\n",
    "\n",
    "The intermediate layers use the **Leaky ReLU** activation function  as ReLU tends to kill gradients and critically slow down convergence. One can also use activation functions such as **ELU** or any other activation of choice that ensures good gradient flow. The last layer uses a $ tanh $ activation in order to constrain the pixel values in the range $(-1 \\to 1)$ . One can easily change the nonlinearity of the intermediate and the last layers as per their preference by passing them as parameters during initialization of the Generator object. \n",
    "\n",
    "### DISCRIMINATOR NETWORK\n",
    "The DCGAN discriminator has an architecture symmetric to the generator. It maps the image to a confidence score in order to classify whether the image is real (i.e. comes from the dataset) or fake (i. e. sampled by the generator)\n",
    "\n",
    "For reasons same as above we use a **Leaky ReLU** activation. The conversion of the image tensor to the confidence scores are as follows:\n",
    "\n",
    "1. **Channel Dimension**: $1 \\rightarrow d \\rightarrow 2 \\times d \\rightarrow 4 \\times d \\rightarrow 1$.\n",
    "2. **Image size**: $(32 \\times 32) \\rightarrow (16 \\times 16) \\rightarrow (8 \\times 8) \\rightarrow (4 \\times 4) \\rightarrow (1 \\times 1)$.\n",
    "\n",
    "*Note: The last layer of the discriminator in most standard implementations of DCGAN have a Sigmoid layer that causes the confidence scores to lie in the interval $(0 \\to 1)$ and allows the easy interpretation of the confidence score as the probability of the image being real. However, this interpretation is only restricted to the Minimax Loss proposed in the original GAN paper and losses such as the Wasserstein Loss do not require such an interpretation. If required, however, one can easily set the activation of the last layer to Sigmoid by passing it as a parameter during initialization time*"
   ]
  },
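  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The channel and image-size progressions listed for the generator can be traced with a few lines of plain Python. This is only a sketch under stated assumptions: it uses the common DCGAN layout in which the first transposed convolution has kernel 4, stride 1 and padding 0, and every later one has kernel 4, stride 2 and padding 1. These layer settings are not read from TorchGAN, and `conv_transpose_out`, `trace_generator` and the value $d = 128$ are hypothetical names and numbers chosen for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output size of a transposed convolution (no dilation, no output padding)\n",
    "def conv_transpose_out(size, kernel, stride, padding):\n",
    "    return (size - 1) * stride - 2 * padding + kernel\n",
    "\n",
    "\n",
    "def trace_generator(encoding_dims, d, out_channels=1):\n",
    "    # Channel plan: encoding dims -> d -> d/2 -> d/4 -> out_channels\n",
    "    channels = [encoding_dims, d, d // 2, d // 4, out_channels]\n",
    "    # The noise vector is treated as a 1 x 1 image\n",
    "    sizes = [1]\n",
    "    # Assumed first layer: kernel 4, stride 1, padding 0\n",
    "    sizes.append(conv_transpose_out(sizes[-1], kernel=4, stride=1, padding=0))\n",
    "    # Assumed remaining layers: kernel 4, stride 2, padding 1 (doubles the size)\n",
    "    while len(sizes) < len(channels):\n",
    "        sizes.append(conv_transpose_out(sizes[-1], kernel=4, stride=2, padding=1))\n",
    "    return list(zip(channels, sizes))\n",
    "\n",
    "\n",
    "# d = 128 is an illustrative value only\n",
    "for channels, size in trace_generator(encoding_dims=100, d=128):\n",
    "    print(f\"{channels:>4} channels, {size} x {size}\")"
   ]
  },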
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "oYwYPJtalGSC"
   },
   "source": [
    "## OPTIMIZERS AND HYPERPARAMETERS\n",
    "\n",
    "The models, their corresponding optimizers and other hyperparameters like the nonlinearities to be used in the intermediate layers are bundled in the form of a dictionary and provided to the trainer for instantiation. The dictionary specifies the models that are to be trained, the optimizers associated with said models and learning rate schedulers, if any\n",
    "\n",
    "\n",
    "1. \"name\": The class name for the model. Generally a subclass of the ```torchgan.models.Generator``` or ```torchgan.models.Discriminator```\n",
    "2. \"args\": Arguments fed into the class during instantiation, into its constructor \n",
    "3. \"optimizer\": A dictionary containing the following key-value pairs defining the optimizer associated with the model\n",
    "    * \"name\" : The class name of the optimizer. Generally an optimizer from the ```torch.optim``` package\n",
    "    * \"args\" : Arguments to be fed to the optimizer during its instantiation, into its constructor\n",
    "    * \"var\": Variable name for the optimizer. This is an optional argument. If this is not provided, we assign the optimizer the name ```optimizer_{}``` where {} refers to the variable name of the model.\n",
    "    * \"scheduler\": Optional scheduler associated with the optimizer. Again this is a dictionary with the following keys\n",
    "        * \"name\" : Class name of the scheduler\n",
    "        * \"args\" : Arguments to be provided to the scheduler during instantiation, into its constructor\n",
    "        \n",
    "This tutorial shows the example for a DCGAN optimized by the Adam optimizer. Head over to the documentation of ```DCGANGenerator``` , ```DCGANDiscriminator``` or the ```torch.optim.Adam``` classes for more details about what each of the args mean (*NB: The args are basically parameters to the constructor of each class declared in \"name\" , as discussed before* ). Also try tinkering with the various hyperparameters like *\"encoding_dims\", \"step_channels\", \"nonlinearity\" and \"last_nonlinearity\"*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "dXQHfhPsja1o"
   },
   "outputs": [],
   "source": [
    "dcgan_network = {\n",
    "    \"generator\": {\n",
    "        \"name\": DCGANGenerator,\n",
    "        \"args\": {\n",
    "            \"encoding_dims\": 100,\n",
    "            \"out_channels\": 1,\n",
    "            \"step_channels\": 32,\n",
    "            \"nonlinearity\": nn.LeakyReLU(0.2),\n",
    "            \"last_nonlinearity\": nn.Tanh(),\n",
    "        },\n",
    "        \"optimizer\": {\"name\": Adam, \"args\": {\"lr\": 0.0001, \"betas\": (0.5, 0.999)}},\n",
    "    },\n",
    "    \"discriminator\": {\n",
    "        \"name\": DCGANDiscriminator,\n",
    "        \"args\": {\n",
    "            \"in_channels\": 1,\n",
    "            \"step_channels\": 32,\n",
    "            \"nonlinearity\": nn.LeakyReLU(0.2),\n",
    "            \"last_nonlinearity\": nn.LeakyReLU(0.2),\n",
    "        },\n",
    "        \"optimizer\": {\"name\": Adam, \"args\": {\"lr\": 0.0003, \"betas\": (0.5, 0.999)}},\n",
    "    },\n",
    "}"
   ]
  },
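  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The dictionary above leaves out the optional \"var\" and \"scheduler\" keys. The fragment below is a sketch of what an \"optimizer\" entry could look like with both of them filled in, following the key layout described in this section. The scheduler class `ExponentialLR` comes from `torch.optim.lr_scheduler`; the choice of scheduler and the value `gamma=0.99` are illustrative assumptions rather than recommended settings, and the dictionary is not used anywhere else in this tutorial."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.optim import Adam\n",
    "from torch.optim.lr_scheduler import ExponentialLR\n",
    "\n",
    "# Sketch of an optimizer entry that uses the optional keys.\n",
    "# The scheduler choice and gamma=0.99 are illustrative values only.\n",
    "generator_optimizer = {\n",
    "    \"name\": Adam,\n",
    "    \"args\": {\"lr\": 0.0001, \"betas\": (0.5, 0.999)},\n",
    "    \"var\": \"optimizer_generator\",\n",
    "    \"scheduler\": {\"name\": ExponentialLR, \"args\": {\"gamma\": 0.99}},\n",
    "}"
   ]
  },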
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "PAlshpjPqwU8"
   },
   "source": [
    "## LOSS FUNCTIONS\n",
    "**TorchGAN** provides a wide variety of GAN losses and penalties off the shelf. One can also easily implement custom losses and integrate it with the highly robust training pipeline\n",
    "\n",
    "1. Minimax Loss\n",
    "2. Wasserstein GAN with Gradient Penalty\n",
    "3. Least Squares GAN or LSGAN\n",
    "\n",
    "\n",
    "The loss objects to be used by the trainer are added in a list as shown"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "598HhE5Sqs3x"
   },
   "outputs": [],
   "source": [
    "minimax_losses = [MinimaxGeneratorLoss(), MinimaxDiscriminatorLoss()]\n",
    "wgangp_losses = [\n",
    "    WassersteinGeneratorLoss(),\n",
    "    WassersteinDiscriminatorLoss(),\n",
    "    WassersteinGradientPenalty(),\n",
    "]\n",
    "lsgan_losses = [LeastSquaresGeneratorLoss(), LeastSquaresDiscriminatorLoss()]"
   ]
  },
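  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the Least Squares objective concrete, the sketch below evaluates the textbook LSGAN losses on a handful of made-up discriminator scores, using a target of 1 for real images and 0 for generated ones. It is written in plain Python and is not TorchGAN's implementation: the helper names are hypothetical, and the library's own loss classes may use a different scaling or reduction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: textbook least squares GAN losses (real target 1, fake target 0).\n",
    "# All helper names here are hypothetical and not part of TorchGAN.\n",
    "def mean(values):\n",
    "    return sum(values) / len(values)\n",
    "\n",
    "\n",
    "def lsgan_discriminator_loss(real_scores, fake_scores):\n",
    "    # Push scores of real images towards 1 and scores of fakes towards 0\n",
    "    real_term = mean([(s - 1.0) ** 2 for s in real_scores])\n",
    "    fake_term = mean([s ** 2 for s in fake_scores])\n",
    "    return 0.5 * (real_term + fake_term)\n",
    "\n",
    "\n",
    "def lsgan_generator_loss(fake_scores):\n",
    "    # The generator wants its samples to be scored as real (1)\n",
    "    return 0.5 * mean([(s - 1.0) ** 2 for s in fake_scores])\n",
    "\n",
    "\n",
    "real_scores = [1.0, 0.5]  # made-up D(x) values\n",
    "fake_scores = [0.0, 0.5]  # made-up D(G(z)) values\n",
    "print(\"discriminator loss:\", lsgan_discriminator_loss(real_scores, fake_scores))\n",
    "print(\"generator loss:\", lsgan_generator_loss(fake_scores))"
   ]
  },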
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "bcHL1oqgyBRD"
   },
   "source": [
    "## VISUALIZE THE TRAINING DATA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 498
    },
    "colab_type": "code",
    "id": "HMkI_5XQxt_Q",
    "outputId": "fc417ed9-e5c2-4026-93b1-635db1e7edc9"
   },
   "outputs": [],
   "source": [
    "# Plot some of the training images\n",
    "real_batch = next(iter(dataloader))\n",
    "plt.figure(figsize=(8, 8))\n",
    "plt.axis(\"off\")\n",
    "plt.title(\"Training Images\")\n",
    "plt.imshow(\n",
    "    np.transpose(\n",
    "        vutils.make_grid(real_batch[0][:64], padding=2, normalize=True).cpu(), (1, 2, 0)\n",
    "    )\n",
    ")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "CEt29MyPyRlE"
   },
   "source": [
    "## TRAINING THE NETWORK\n",
    "\n",
    "*NB: Training the models are quite expensive. Hence we will train the models for **10** epochs if a GPU is available, else we will be training for only **5** epochs. We recommend using the **GPU runtime** in Colab. The images will not look even close to realistic in **5** epochs but shall be enough to show that it is learning to generate good quality images. If you have access to powerful GPUs or want to see realistic samples, I would recommend simply increasing the **epochs** variable in the next code block*\n",
    "\n",
    "---\n",
    "\n",
    "The trainer is initialized by passing the network descriptors and the losses, and then calling the trainer on the dataset. \n",
    "The *sample_size* parameter decides how many images to sample for visualization at the end of every epoch, and the *epochs* parameter decides the number of epochs.\n",
    "We illustrate the training process by training an LSGAN. Simply change the losses list passed from *lsgan_losses* to *wgangp_losses* to train a Wasserstein GAN with Gradient Penalty, or to *minimax_losses* to train a Minimax GAN\n",
    "\n",
    "---\n",
    "\n",
    "Important information for visualizing the performance of the GAN is printed to the console. We also provide a **Visualizer API** for visualizing the various losses, gradient flow and generated images. Setting up the Visualizer using either a **TensorboardX** or **Vizdom** backend is the recommended approach for visualizing the training process. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 53
    },
    "colab_type": "code",
    "id": "Szgznt54yHaI",
    "outputId": "3078e59d-44aa-4ee6-9f74-cea7c909b187"
   },
   "outputs": [],
   "source": [
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda:0\")\n",
    "    # Use deterministic cudnn algorithms\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    epochs = 10\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "    epochs = 5\n",
    "\n",
    "print(\"Device: {}\".format(device))\n",
    "print(\"Epochs: {}\".format(epochs))"
   ]
  },
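  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get a rough feel for what these epoch counts mean, the sketch below counts the batches the DataLoader yields. It assumes the 60000 images of the MNIST training split and the batch size of 64 used earlier, with the incomplete final batch kept. It only counts batches; how many optimizer steps the trainer takes per batch is not covered here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "# Sketch: number of batches per epoch for the MNIST training split.\n",
    "num_images = 60000  # size of the MNIST training split\n",
    "batch_size = 64  # batch size passed to the DataLoader earlier\n",
    "batches_per_epoch = math.ceil(num_images / batch_size)\n",
    "# The final batch holds whatever is left over\n",
    "last_batch_size = num_images - (batches_per_epoch - 1) * batch_size\n",
    "\n",
    "print(\"batches per epoch:\", batches_per_epoch)\n",
    "print(\"size of the last batch:\", last_batch_size)\n",
    "for num_epochs in (5, 10):\n",
    "    print(num_epochs, \"epochs ->\", num_epochs * batches_per_epoch, \"batches\")"
   ]
  },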
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "5ai_Vm2a0ecx"
   },
   "outputs": [],
   "source": [
    "trainer = Trainer(\n",
    "    dcgan_network, lsgan_losses, sample_size=64, epochs=epochs, device=device\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 5795
    },
    "colab_type": "code",
    "id": "ig1cT2Ve1hzB",
    "outputId": "72745ce1-1cb9-42cb-bbd0-42d59f1f7979"
   },
   "outputs": [],
   "source": [
    "trainer(dataloader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "yhu0fhm7EY3m"
   },
   "source": [
    "## VISUALIZING THE SAMPLES\n",
    "Once training is complete, one can easily visualize the loss curves, gradient flow and sampled images per epoch on either the **TensorboardX** or **Vizdom** backends. For the purposes of this tutorial, we plot some of the sampled images here itself.\n",
    "\n",
    "*NB: It is highly recommended to view the results on TensorboardX or Vizdom if you are running this tutorial locally*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 317
    },
    "colab_type": "code",
    "id": "Dm8Husvr4aA1",
    "outputId": "634e893e-b98b-466f-9cfa-c7a396a13f51"
   },
   "outputs": [],
   "source": [
    "# Grab a batch of real images from the dataloader\n",
    "real_batch = next(iter(dataloader))\n",
    "\n",
    "# Plot the real images\n",
    "plt.figure(figsize=(10, 10))\n",
    "plt.subplot(1, 2, 1)\n",
    "plt.axis(\"off\")\n",
    "plt.title(\"Real Images\")\n",
    "plt.imshow(\n",
    "    np.transpose(\n",
    "        vutils.make_grid(\n",
    "            real_batch[0].to(device)[:64], padding=5, normalize=True\n",
    "        ).cpu(),\n",
    "        (1, 2, 0),\n",
    "    )\n",
    ")\n",
    "\n",
    "# Plot the fake images from the last epoch\n",
    "plt.subplot(1, 2, 2)\n",
    "plt.axis(\"off\")\n",
    "plt.title(\"Fake Images\")\n",
    "plt.imshow(plt.imread(\"{}/epoch{}_generator.png\".format(trainer.recon, trainer.epochs)))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1137
    },
    "colab_type": "code",
    "id": "l5fxuZC3FJG2",
    "outputId": "df798b64-421c-47b5-ccc8-cb4ed1caa028"
   },
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(8, 8))\n",
    "plt.axis(\"off\")\n",
    "ims = [\n",
    "    [plt.imshow(plt.imread(\"{}/epoch{}_generator.png\".format(trainer.recon, i)))]\n",
    "    for i in range(1, trainer.epochs + 1)\n",
    "]\n",
    "ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)\n",
    "\n",
    "# Play the animation\n",
    "HTML(ani.to_jshtml())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "ImDvtrtaFNdk"
   },
   "source": [
    "## TRAINING CONDITIONAL GAN MODELS\n",
    "TorchGAN supports semi-supervised learning off the shelf through its Conditional GAN and Auxiliary Classifier GAN models and losses, that condition on the labels. Note that it is mandatory for the dataset to have labels for semi-supervised learning. We illustrate this by training a **Conditional DCGAN** on **MNIST**, conditioning the model on the identity of the digit\n",
    "\n",
    "Generator and Discriminator architecture remain the same as that in DCGAN except the number of class labels has to be passed as an additional parameter in the dictionary defining the model.We reuse all the hyperparameters from the previous section"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "WePwyU73FJuR"
   },
   "outputs": [],
   "source": [
    "cgan_network = {\n",
    "    \"generator\": {\n",
    "        \"name\": ConditionalGANGenerator,\n",
    "        \"args\": {\n",
    "            \"encoding_dims\": 100,\n",
    "            \"num_classes\": 10,  # MNIST digits range from 0 to 9\n",
    "            \"out_channels\": 1,\n",
    "            \"step_channels\": 32,\n",
    "            \"nonlinearity\": nn.LeakyReLU(0.2),\n",
    "            \"last_nonlinearity\": nn.Tanh(),\n",
    "        },\n",
    "        \"optimizer\": {\"name\": Adam, \"args\": {\"lr\": 0.0001, \"betas\": (0.5, 0.999)}},\n",
    "    },\n",
    "    \"discriminator\": {\n",
    "        \"name\": ConditionalGANDiscriminator,\n",
    "        \"args\": {\n",
    "            \"num_classes\": 10,\n",
    "            \"in_channels\": 1,\n",
    "            \"step_channels\": 32,\n",
    "            \"nonlinearity\": nn.LeakyReLU(0.2),\n",
    "            \"last_nonlinearity\": nn.Tanh(),\n",
    "        },\n",
    "        \"optimizer\": {\"name\": Adam, \"args\": {\"lr\": 0.0003, \"betas\": (0.5, 0.999)}},\n",
    "    },\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "sxaQhUhcL-4o"
   },
   "source": [
    "## LOSS FUNCTIONS\n",
    "We reuse the Least Squares loss used to train the DCGAN in the previous section "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "gCbWrB-7LMy5"
   },
   "outputs": [],
   "source": [
    "trainer_cgan = Trainer(\n",
    "    cgan_network, lsgan_losses, sample_size=64, epochs=epochs, device=device\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 449
    },
    "colab_type": "code",
    "id": "1bp86TPYLmfZ",
    "outputId": "803e10e8-a51f-420d-c2cf-c9d301bc1d49"
   },
   "outputs": [],
   "source": [
    "trainer_cgan(dataloader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "NH3QG6CrMO-d"
   },
   "source": [
    "## VISUALIZING THE SAMPLES"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ZcjulvKkL6kx"
   },
   "outputs": [],
   "source": [
    "# Grab a batch of real images from the dataloader\n",
    "real_batch = next(iter(dataloader))\n",
    "\n",
    "# Plot the real images\n",
    "plt.figure(figsize=(10, 10))\n",
    "plt.subplot(1, 2, 1)\n",
    "plt.axis(\"off\")\n",
    "plt.title(\"Real Images\")\n",
    "plt.imshow(\n",
    "    np.transpose(\n",
    "        vutils.make_grid(\n",
    "            real_batch[0].to(device)[:64], padding=5, normalize=True\n",
    "        ).cpu(),\n",
    "        (1, 2, 0),\n",
    "    )\n",
    ")\n",
    "\n",
    "# Plot the fake images from the last epoch\n",
    "plt.subplot(1, 2, 2)\n",
    "plt.axis(\"off\")\n",
    "plt.title(\"Fake Images\")\n",
    "plt.imshow(plt.imread(\"{}/epoch{}_generator.png\".format(trainer.recon, trainer.epochs)))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "cVO8FwGPMV2t"
   },
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(8, 8))\n",
    "plt.axis(\"off\")\n",
    "ims = [\n",
    "    [plt.imshow(plt.imread(\"{}/epoch{}_generator.png\".format(trainer.recon, i)))]\n",
    "    for i in range(1, trainer.epochs + 1)\n",
    "]\n",
    "ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)\n",
    "\n",
    "# Play the animation\n",
    "HTML(ani.to_jshtml())"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "Introduction To TorchGAN.ipynb",
   "provenance": [],
   "toc_visible": true,
   "version": "0.3.2"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
